{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 批量预测"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在对模型进行评估或者一些其他任务时，我们可能希望对一大批数据进行预测，而模型推理又是一个比较耗时的过程，通过循环串行执行比较耗时，而并行执行又需要额外进行开发。\n",
    "\n",
    "SDK 为这一场景提供了批量预测功能，可以并行处理一批数据，简化使用。\n",
    "\n",
    "注意：需要 SDK 版本 >= 0.1.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import qianfan\n",
    "import os\n",
    "\n",
    "# 这里请根据 SDK 文档获取自己的 access key 和 secret key\n",
    "os.environ[\"QIANFAN_ACCESS_KEY\"] = \"your access key\"\n",
    "os.environ[\"QIANFAN_SECRET_KEY\"] = \"your secret key\"\n",
    "# 由于并行会带来较大的并发量，容易触发 QPS 限制导致请求失败\n",
    "# 因此建议这里为 SDK 设置一个合理的 QPS 限制，SDK 会根据该数值进行流控\n",
    "os.environ[\"QIANFAN_QPS_LIMIT\"] = \"3\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "需要准备的数据与需要使用的模型类型相关，例如 `Completion` 接受的参数 `prompt` 类型是 `str`，那么批量预测时需要使用的就是 `List[str]`。\n",
    "\n",
    "这里以 CMMLU 数据集为例，使用 `Completion` 模型进行批量预测。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 这里使用 HuggingFace 上的数据集，如果已经有数据集可以跳过这步\n",
    "!pip install datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datasets\n",
    "dataset = datasets.load_dataset(\"haonan-li/cmmlu\", 'chinese_literature')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "《日出》的结构采用的是\n",
      "以眉间尺为父报仇作为中心线索的小说是\n",
      "“他只是遗憾／他的祖先没有像他一样想过／不然，见到大海的该是他了”。以此作为结尾的诗歌是\n",
      "下列含有小说《伤逝》的鲁迅作品集是\n",
      "汪静之属于\n",
      "杂文《春末闲谈》中，细腰蜂的故事所要表达的是\n",
      "臧克家的第一部诗集是\n",
      "下列小说中运用了“戏剧穿插法”的是\n",
      "台静农的《记波外翁》中波外翁的性格是\n",
      "下列哪一项不属于巴金的短篇小说集\n"
     ]
    }
   ],
   "source": [
    "# 这里仅取前 10 个样本进行测试\n",
    "# data 只需要是 List[str] 类型即可\n",
    "data = dataset['test']['Question'][:10]\n",
    "for i in data:\n",
    "    print(i)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 批量预测"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以将需要预测的所有数据放在一个 List 中，然后调用 `batch_do` 方法，SDK 会在后台启动数个线程并行处理，线程数量由 `worker` 参数指定。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "r = qianfan.Completion().batch_do(data, worker_num=5)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "返回的结果 `r` 是一个 future 对象，可以调用 `r.result()` 等待所有数据预测完成，并返回结果。\n",
    "\n",
    "或者也可以直接遍历 r，逐个获取结果，在遇到未完成的任务时再等待。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "问题0：\n",
      "《日出》的结构采用的是\n",
      "《日出》的结构采用的是戏剧常见的三一律。这种布局形式要求剧情人物和事件集中单一，情节发生在一天之内，情节结构单一紧凑，戏剧冲突集中在一个高潮，以此突出尖锐的矛盾冲突。同时，剧中没有幕场和舞台指示，时间地点也随着剧中人物生活的场景变化而变化。总的来说，《日出》遵循了戏剧结构的基本原则，使剧情更加紧凑集中。\n",
      "=================\n",
      "问题1：\n",
      "以眉间尺为父报仇作为中心线索的小说是\n",
      "以眉间尺为父报仇作为中心线索的小说是《七侠五义之后记》、《武林奇事》、《乱世枭雄眉间尺为父报仇》等。\n",
      "\n",
      "此外，还有鲁迅的《铸剑》也以眉间尺为父报仇作为中心线索。这些小说讲述了眉间尺在经历了种种磨难和考验后，最终成功为父报仇的故事。\n",
      "=================\n",
      "问题2：\n",
      "“他只是遗憾／他的祖先没有像他一样想过／不然，见到大海的该是他了”。以此作为结尾的诗歌是\n",
      "这段文字似乎是一种诗句的结尾，它可以属于很多诗歌，因为诗歌的语言往往没有特定的含义或特定的主题。以下是我提供的一些可能的诗歌结尾，它们都与这段文字有关：\n",
      "\n",
      "1. “他只是遗憾，他的祖先没有飞翔/否则，他可能已经看见了天堂。”\n",
      "2. “他只是惋惜，他的祖先没有梦想/否则，他们可能已经拥有了自己的海洋。”\n",
      "3. “他望着大海，只有遗憾/如果他的祖先有他的勇气，他会和海鸥一起飞翔。”\n",
      "4. “如果他懂得时间的流转/他也许会笑自己遗憾/毕竟大海一直都在。”\n",
      "\n",
      "以上只是根据你提供的信息进行的一些推测，可能存在偏差或不符。因为具体情境与意象的选择很大程度上取决于个人的创作风格和意图。希望这些想法可以帮到你。\n",
      "=================\n"
     ]
    }
   ],
   "source": [
    "results = r.results()\n",
    "# 这里为了展示结果，仅取了前 3 个结果\n",
    "# 实际使用时可以去掉该限制\n",
    "for i, (input, output) in enumerate(zip(data[:3], results)):\n",
    "    print(f\"问题{i}：\")\n",
    "    print(input)\n",
    "    print(output['result'])\n",
    "    print(\"=================\")\n",
    "\n",
    "# 或者\n",
    "for i, (input, output) in enumerate(zip(data, r)):\n",
    "    print(f\"问题{i}：\")\n",
    "    print(input)\n",
    "    print(output.result()['result'])\n",
    "    print(\"=================\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`r.results()` 会同时返回所有结果，当数据量较大时可能会占用大量内存，可以调用 `r.wait()` 仅等待而不返回结果。\n",
    "\n",
    "此外可以通过 `r.finished_count()` 获取已完成的任务数量，`r.finished_count()` 获取未完成的任务数量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10/10\n"
     ]
    }
   ],
   "source": [
    "# 获取完成的比例\n",
    "print(\"{}/{}\".format(r.finished_count(), r.finished_count()))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "对于其他类型模型，数据 `List` 中的元素类型需要进行改变，例如 `ChatCompletion` 接受的参数是 Messages List，那么就要进行转换。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_data = [[{\n",
    "    \"role\": \"user\",\n",
    "    \"content\": question\n",
    "}] for question in data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "《日出》的结构采用的是戏剧的“幕”的形式。该剧分为多个幕，每个幕描述一个场景，描述不同人物的行动和心理状态，形成了一个完整的故事结构。这种结构方式使得剧情更加紧凑，也更容易让观众沉浸在故事中。\n"
     ]
    }
   ],
   "source": [
    "r = qianfan.ChatCompletion().batch_do(chat_data, worker_num=5)\n",
    "r.wait()\n",
    "print(r.results()[0]['result'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果推理过程发生错误，会在调用 `result()` 时抛出异常，请注意进行异常处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[ERROR] [11-20 19:27:05] logging.py:86 [t:140418730710784]: api request failed with error code: 336002, err msg: Invalid JSON, please check https://cloud.baidu.com/doc/WENXINWORKSHOP/s/tlmyncueh\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "数据0发生异常：api return error, code: 336002, msg: Invalid JSON\n",
      "数据1结果：你好，有什么我可以帮助你的吗？\n"
     ]
    }
   ],
   "source": [
    "err_data=[\"你好\", [{\"role\":\"user\", \"content\":\"你好\"}]]\n",
    "\n",
    "r = qianfan.ChatCompletion().batch_do(err_data)\n",
    "\n",
    "for i, res in enumerate(r):\n",
    "    try:\n",
    "        print(\"数据{}结果：{}\".format(i, res.result()['result']))\n",
    "    except Exception as e:\n",
    "        print(f\"数据{i}发生异常：{e}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 异步调用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "《日出》的结构采用的是\n",
      "《日出》的结构采用的是戏剧的“幕”的形式。该剧分为多个幕，每个幕描述一个场景，描述不同人物的行动和心理状态，形成了一个完整的故事结构。这种结构方式使得剧情更加紧凑，也更容易让观众沉浸在故事中。\n",
      "==============\n",
      "以眉间尺为父报仇作为中心线索的小说是\n",
      "以眉间尺为父报仇作为中心线索的小说是《七侠五义番外之三探飞云堡》。这部小说以眉间尺为父报仇为主线，讲述了眉间尺历经磨难，最终成功为父报仇的故事。同时，小说中也涉及到了包公、楚云、白玉堂等主要人物，以及一些江湖恩怨和朝廷政治斗争的情节。\n",
      "==============\n",
      "“他只是遗憾／他的祖先没有像他一样想过／不然，见到大海的该是他了”。以此作为结尾的诗歌是\n",
      "这个结尾的诗歌，结合了你的描述，我推测可能是这样的：\n",
      "\n",
      "“他只是遗憾/他的祖先没有像他一样勇敢/不然，见到大海的该是他了。”\n",
      "\n",
      "这首诗歌可能表达了对祖先没有像他一样有勇气去探索的遗憾，同时也表达了对未来的期待和向往。这样的结尾，既表达了诗人的感慨，也留下了对未来的想象和期待。\n",
      "==============\n"
     ]
    }
   ],
   "source": [
    "results = await qianfan.Completion().abatch_do(data, worker_num=5)\n",
    "# 返回值为一个 List，与输入列表中的元素一一对应\n",
    "# 正常情况下与 `ado` 返回类型一致，但如果发生异常则会是一个 Exception 对象\n",
    "for prompt, result in zip(data[:3], results):\n",
    "    if not isinstance(result, Exception):\n",
    "        print(prompt)\n",
    "        print(result['result'])\n",
    "        print(\"==============\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "同样的异步场景下，如果发生了错误需要进行异常处理，此时 results 中某个数据发生错误，那么对应的对象会是 Excepiton 类型，可以通过 `isinstance` 进行识别。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "err_data=[\"你好\", [{\"role\":\"user\", \"content\":\"你好\"}]]\n",
    "\n",
    "results = await qianfan.ChatCompletion().abatch_do(err_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "数据0发生错误：api return error, code: 336002, msg: Invalid JSON\n",
      "数据1结果：你好，有什么我可以帮助你的吗？\n"
     ]
    }
   ],
   "source": [
    "for i, result in enumerate(results):\n",
    "    if not isinstance(result, Exception):\n",
    "        print(\"数据{}结果：{}\".format(i, result['result']))\n",
    "    else:\n",
    "        print(\"数据{}发生错误：{}\".format(i, result))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "f553a591cb5da27fa30e85168a93942a1a24c8d6748197473adb125e5473a5db"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
